import copy
import json

import numpy as np
import torch

from Causal_Partial_Mnist.Find_CF_Synthetic_Distribution_Mnist import get_intv_dist, get_synthetic_dist
from Causal_Partial_Mnist.RejectionSampling_Optimized import rejection_sampling_optimized
from Causal_Partial_Mnist.True_Counterfactuals_Mnist import get_cf_dist
from ModularUtils.ControllerConstants import map_fill_to_discrete, map_dictfill_to_discrete
from ModularUtils.ControllerModel import get_generated_labels
from ModularUtils.FunctionsConstant import getdoKey
from ModularUtils.FunctionsDistribution import compare_conditionals_within, calculate_TVD, match_with_true_dist
from ModularUtils.FunctionsTraining import save_results


def compare_conditionals(Exp, label_generators, obs_real_dataset, vars):
    """Print per-mechanism conditional TVDs between generated and real labels.

    Samples ``Exp.Synthetic_Sample_Size`` labels for every variable in
    ``Exp.label_names`` through ``get_generated_labels``, discretizes them
    and, for every feature and every label, compares ``P(label | conditions)``
    estimated on the generated samples with the same quantity estimated on
    ``obs_real_dataset``.  ``conditions`` is
    ``Exp.train_mech_list[i]["compare"]`` with the label itself removed.

    Args:
        Exp: Experiment configuration.  This function reads ``label_names``,
            ``features``, ``label_dim``, ``Synthetic_Sample_Size`` and
            ``train_mech_list``.
        label_generators: Generators forwarded to ``get_generated_labels``.
        obs_real_dataset: Real observational labels; must support
            ``.detach().cpu().numpy()``.
        vars: Forwarded to ``get_generated_labels``; ``vars["mech"]`` names
            the label whose TVD is returned.

    Returns:
        The TVD (divided by the product of the conditioning variables'
        dimensions) of the label equal to ``vars["mech"]`` for the last
        feature iterated, or ``0`` if no label matched.

    Raises:
        ValueError: If a label is missing from its own ``"compare"`` list
            (raised by ``list.remove``).
    """
    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, {}, Exp.label_names,
                                                 Exp.Synthetic_Sample_Size, vars)

    # One width per (label, feature) pair; the total is the flattened row width.
    dims_list = [Exp.label_dim[lb][feat] for lb in Exp.label_names for feat in Exp.features]
    stacked_labels = torch.cat(list(generated_labels_dict.values()), 1).view(-1, sum(dims_list))
    generated_labels_full = map_fill_to_discrete(Exp, stacked_labels,
                                                 dims_list).detach().cpu().numpy().astype(int)

    # The real dataset does not change between comparisons, so convert it once
    # instead of once per (feature, label) pair.
    real_labels_full = obs_real_dataset.detach().cpu().numpy().astype(int)

    # Bound before the loops so the return value always exists.
    mech_tvd = 0
    for feat in Exp.features:
        for lbid, label in enumerate(Exp.label_names):
            # Copy before removing so Exp.train_mech_list is left untouched.
            conditions = copy.deepcopy(Exp.train_mech_list[lbid]["compare"])
            conditions.remove(label)

            pstr = feat + ":P(" + label + "|" + str(conditions) + ")"
            print(pstr)

            generated_cond = compare_conditionals_within(Exp, generated_labels_full, feat, [label], conditions,
                                                         doPrint=False)
            real_cond = compare_conditionals_within(Exp, real_labels_full, feat, [label], conditions,
                                                    doPrint=False)

            # Normalize by the number of joint settings of the conditioning variables.
            div = np.prod([Exp.label_dim[lb][feat] for lb in conditions])
            tvd = calculate_TVD(generated_cond, real_cond, doPrint=False) / div
            print(pstr + " TVD:", tvd)
            print("--------------")

            if label == vars["mech"]:
                mech_tvd = tvd

    return mech_tvd





def get_observational_loss(Exp, label_generators, obs_bn, tvd_diff, kl_diff):
    """Append the TVD/KL of the generated joint ``P(V)`` to the loss histories.

    Args:
        Exp: Experiment configuration; ``label_names`` and
            ``Synthetic_Sample_Size`` are read here.
        label_generators: Generators forwarded to ``get_generated_labels``.
        obs_bn: Not used by this function; kept for interface compatibility.
        tvd_diff: Mapping from query string to a list of TVD values; the new
            value is appended under ``"P(V)"``.
        kl_diff: Mapping from query string to a list of KL values; the new
            value is appended under ``"P(V)"``.

    Returns:
        The ``(tvd_diff, kl_diff)`` pair, updated in place.
    """
    feature_name = "feature"
    query_key = "P(V)"
    all_labels = Exp.label_names

    # Sample every variable with no evidence and no intervention, then discretize.
    sampled = get_generated_labels(
        Exp, label_generators, {}, {}, {}, all_labels, Exp.Synthetic_Sample_Size
    )
    sampled_discrete = map_dictfill_to_discrete(Exp, sampled, all_labels)

    # Reference distribution for the same (intervention-free) query.
    reference_dist = get_intv_dist(Exp, all_labels, {}, query_key)

    distances = match_with_true_dist(
        Exp, all_labels, sampled_discrete, reference_dist, feature_name, doPrint=False
    )
    tvd_value, kl_value = distances

    tvd_diff[query_key].append(tvd_value)
    kl_diff[query_key].append(kl_value)
    return tvd_diff, kl_diff


def get_expected_loss_interventions(Exp, cur_mech, label_generators, tvd_diff, kl_diff):
    """Append expected interventional TVD/KL for every query involving ``cur_mech``.

    For each entry of ``Exp.interv_queries`` whose ``"obs"`` list contains
    ``cur_mech``, every intervention in ``query["intervs"]`` is evaluated and
    its TVD/KL is weighted by the probability of that intervention value under
    the distribution returned by ``get_intv_dist`` for the intervened
    variables with an empty intervention.  The weighted sums, rounded to four
    decimals, are appended under ``query["expr"]``.

    Args:
        Exp: Experiment configuration; ``interv_queries`` and
            ``Synthetic_Sample_Size`` are read here.
        cur_mech: Name of the mechanism currently being trained.
        label_generators: Generators forwarded to ``get_generated_labels``.
        tvd_diff: Mapping from query expression to a list of TVD values.
        kl_diff: Mapping from query expression to a list of KL values.

    Returns:
        The ``(tvd_diff, kl_diff)`` pair, updated in place.
    """
    feature_name = "feature"

    relevant_queries = (q for q in Exp.interv_queries if cur_mech in q["obs"])
    for query in relevant_queries:
        # Variables being intervened on, taken from the first intervention.
        intervened_vars = list(query["intervs"][0].keys())
        weight_key = getdoKey(intervened_vars, {})
        # Distribution of the intervened variables, used as weights below.
        weight_dist = get_intv_dist(Exp, intervened_vars, {}, weight_key)

        expected_tvd = 0
        expected_kl = 0
        for intervention in query["intervs"]:
            do_key = getdoKey(query["obs"], intervention)
            reference_dist = get_intv_dist(Exp, query["obs"], intervention, do_key)

            sampled = get_generated_labels(
                Exp, label_generators, {}, {}, intervention, query["obs"], Exp.Synthetic_Sample_Size
            )
            sampled_discrete = map_dictfill_to_discrete(Exp, sampled, query["obs"])
            cur_tvd, cur_kl = match_with_true_dist(
                Exp, query["obs"], sampled_discrete, reference_dist, feature_name, doPrint=False
            )

            # The printed bound is sqrt(KL / 2).
            print(f'{intervention}: tvd:{cur_tvd}, kl:{cur_kl} and tvd<={np.sqrt(0.5 * cur_kl)}')

            weight = weight_dist[tuple(intervention.values())]
            expected_tvd += cur_tvd * weight
            expected_kl += cur_kl * weight

        print(f'--->Average tvd:{expected_tvd}, kl:{expected_kl} and tvd<={np.sqrt(0.5 * expected_kl)}')
        tvd_diff[query["expr"]].append(round(expected_tvd, 4))
        kl_diff[query["expr"]].append(round(expected_kl, 4))

    return tvd_diff, kl_diff


def get_expected_loss_countefactuals(Exp, cur_mech, label_generators, obs_bn, tvd_diff, kl_diff):
    """Append the expected counterfactual TVD/KL for the first CF query.

    Only ``Exp.cf_queries[0]`` is evaluated.  If ``cur_mech`` is not among
    that query's ``"obs"`` variables the histories are returned unchanged.
    Otherwise, for every evidence assignment, posterior samples are drawn
    with ``rejection_sampling_optimized`` and each intervention is compared
    with the distribution returned by ``get_cf_dist``; the results are
    weighted with ``obs_dist`` and appended under ``cfquery["expr"]``.

    Args:
        Exp: Experiment configuration; ``cf_queries``, ``twin_map`` and
            ``Synthetic_Sample_Size`` are read here.
        cur_mech: Name of the mechanism currently being trained.
        label_generators: Generators used for sampling.
        obs_bn: Mapping indexed with ``"feature"``; the entry is passed to
            ``get_synthetic_dist``.
        tvd_diff: Mapping from query expression to a list of TVD values.
        kl_diff: Mapping from query expression to a list of KL values.

    Returns:
        The ``(tvd_diff, kl_diff)`` pair, updated in place.
    """
    feature_name = "feature"
    query = Exp.cf_queries[0]

    if cur_mech not in query["obs"]:
        return tvd_diff, kl_diff

    # Evidence variables are translated through Exp.twin_map before the lookup.
    mapped_evidence_vars = [Exp.twin_map[name] for name in query["evidence"][0].keys()]
    synthetic_result = get_synthetic_dist(Exp, mapped_evidence_vars, obs_bn[feature_name])
    weight_dist = synthetic_result[3]

    total_tvd = 0
    total_kl = 0

    for evidence in query["evidence"]:
        sample_count = Exp.Synthetic_Sample_Size
        posterior = rejection_sampling_optimized(Exp,
                                                 label_generators,
                                                 sample_count,
                                                 evidence,
                                                 max_rejections=0,
                                                 warn=100)
        posterior_labels, posterior_latents, posterior_gumbel = posterior

        evidence_tvd = 0
        evidence_kl = 0
        for intervention in query["intervs"]:
            cf_labels = get_generated_labels(Exp, label_generators, posterior_labels, posterior_latents,
                                             intervention, query["obs"], sample_count,
                                             gumbel_noise=posterior_gumbel)
            cf_discrete = map_dictfill_to_discrete(Exp, cf_labels, query["obs"])

            reference_cf = get_cf_dist(Exp, query["obs"], intervention, evidence, query["expr"],
                                       load_dist=True)
            cur_tvd, cur_kl = match_with_true_dist(Exp, query["obs"], cf_discrete, reference_cf,
                                                   feature_name, doPrint=False)

            # NOTE(review): the weight is looked up by the intervention values in a
            # distribution built over the evidence variables - confirm this is intended.
            intervention_weight = weight_dist[tuple(intervention.values())]
            evidence_tvd += cur_tvd * intervention_weight
            evidence_kl += cur_kl * intervention_weight

            print(f"CF query done for evidence:{evidence}, intv_key: {intervention} ")
            print(evidence_tvd, evidence_kl)

        evidence_weight = weight_dist[tuple(evidence.values())]
        total_tvd += evidence_tvd * evidence_weight
        total_kl += evidence_kl * evidence_weight

    tvd_diff[query["expr"]].append(total_tvd)
    kl_diff[query["expr"]].append(total_kl)

    return tvd_diff, kl_diff


def _detached_copy(labels):
    """Return ``labels`` as a fresh tensor copy, whether or not it already is a tensor."""
    if torch.is_tensor(labels):
        # torch.tensor(<tensor>) emits a UserWarning; detach().clone() is the
        # documented equivalent for copy-constructing from a tensor.
        return labels.detach().clone()
    return torch.tensor(labels)


def evaluate_after_epochs(Exp, cur_mech, label_generators, dataset_dict, tvd_diff, kl_diff):
    """Evaluate the generators on every dataset that does not intervene on ``cur_mech``.

    The generators are switched to eval mode and, under ``torch.no_grad()``,
    each ``(intv_key, dataset)`` entry of ``dataset_dict`` whose intervention
    does not contain ``cur_mech`` is scored: labels are generated for the
    non-intervened variables of ``Exp.label_names`` up to and including
    ``cur_mech`` and compared with the distribution from ``get_intv_dist``.
    The expected interventional losses are then appended, results are written
    with ``save_results`` and the generators are switched back to train mode.

    Args:
        Exp: Experiment configuration; ``label_names``, ``DEVICE``,
            ``Synthetic_Sample_Size``, ``SAVED_PATH``, ``G_avg_losses`` and
            ``D_avg_losses`` are read here.
        cur_mech: Name of the mechanism currently being trained.
        label_generators: Mapping from name to generator module.
        dataset_dict: Mapping from an intervention key (convertible with
            ``dict(...)``) to a 2-D label tensor whose columns follow
            ``Exp.label_names``.
        tvd_diff: Mapping from query string to a list of TVD values.
        kl_diff: Mapping from query string to a list of KL values.

    Returns:
        The ``(tvd_diff, kl_diff)`` pair, updated in place.
    """
    for gen in label_generators:
        label_generators[gen].eval()

    try:
        with torch.no_grad():
            feat = "feature"
            all_generated_labels = {}
            all_real_labels = {}

            for intv_key, each_dataset in dataset_dict.items():
                intervened = dict(intv_key)
                if cur_mech in intervened:  # cur_mech is being intervened on; nothing to evaluate.
                    continue

                # Non-intervened variables up to and including cur_mech.
                compare_Var = []
                for lb in Exp.label_names:
                    if lb in intervened:
                        continue
                    compare_Var.append(lb)
                    if lb == cur_mech:
                        break

                obs_indices = [Exp.label_names.index(lb) for lb in compare_Var]
                current_real_label = each_dataset[:, obs_indices].type(torch.LongTensor).view(
                    -1, len(obs_indices)).to(Exp.DEVICE)

                # Each callee gets its own dict(intv_key), as before, so none of
                # them can observe another's mutations.
                generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, dict(intv_key),
                                                             compare_Var, Exp.Synthetic_Sample_Size)
                generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var)

                query_str = getdoKey(compare_Var, dict(intv_key))
                true_dist_dict = get_intv_dist(Exp, compare_Var, dict(intv_key), query_str)

                obs_tvd, obs_kl = match_with_true_dist(Exp, compare_Var, generated_labels_full, true_dist_dict,
                                                       feat, doPrint=False)

                tvd_diff[query_str].append(round(obs_tvd, 4))  # TODO: fix it
                kl_diff[query_str].append(round(obs_kl, 4))  # TODO: fix it
                all_generated_labels[intv_key] = _detached_copy(generated_labels_full)
                all_real_labels[intv_key] = _detached_copy(current_real_label)

            tvd_diff, kl_diff = get_expected_loss_interventions(Exp, cur_mech, label_generators, tvd_diff, kl_diff)

            save_results(Exp, Exp.SAVED_PATH, all_generated_labels, all_real_labels,
                         tvd_diff, kl_diff, Exp.G_avg_losses, Exp.D_avg_losses)
    finally:
        # Restore train mode even if evaluation raised, so training can continue.
        for gen in label_generators:
            label_generators[gen].train()

    # Print at most the last 10 entries of every history.  Slicing each list
    # with [-10:] also works when tvd_diff is empty or a history is still empty.
    for dist in tvd_diff:
        print("###", dist, " loss%:", tvd_diff[dist][-10:])
    print(Exp.SAVED_PATH)

    return tvd_diff, kl_diff


